/*
 * This file contains the implementation of the 'einsum' function,
 * which provides an einstein-summation operation.
 *
 * Copyright (c) 2011 by Mark Wiebe (mwwiebe@gmail.com)
 * The University of British Columbia
 *
 * See LICENSE.txt for the license.
 */

#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <structmember.h>

#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#define _MULTIARRAYMODULE
#include <numpy/npy_common.h>
#include <numpy/arrayobject.h>

#include <array_assign.h>   //PyArray_AssignRawScalar

#include <ctype.h>
extern "C" {
#include "convert.h"
#include "common.h"
#include "ctors.h"

#include "einsum_sumprod.h"
#include "einsum_debug.h"
}

/*
 * Parses the subscripts for one operand into an output of 'ndim'
 * labels. The resulting 'op_labels' array will have:
 *  - the ASCII code of the label for the first occurrence of a label;
 *  - the (negative) offset to the first occurrence of the label for
 *    repeated labels;
 *  - zero for broadcast dimensions, if subscripts has an ellipsis.
 * For example:
 *  - subscripts="abbcbc",  ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2]
 *  - subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99]
 *
 * subscripts:   Start of this operand's subscripts (only the first
 *               'length' characters are read).
 * length:       Number of characters belonging to this operand.
 * ndim:         Number of dimensions of the operand.
 * iop:          Operand index, only used in error messages.
 * op_labels:    Output, receives 'ndim' entries as described above.
 * label_counts: Table indexed by the ASCII code of a label; the entry of
 *               every label seen is incremented. The count saturates at
 *               127 instead of wrapping, so a label that is used very
 *               often can never appear to be unused or used only once.
 * min_label,
 * max_label:    Updated in place to cover every label seen.
 *
 * Returns 0 on success, or -1 with a ValueError set on failure.
 */
static int
parse_operand_subscripts(char *subscripts, int length,
                         int ndim, int iop, char *op_labels,
                         char *label_counts, int *min_label, int *max_label)
{
    /* Highest count that fits in a 'char' whether it is signed or not. */
    enum { LABEL_COUNT_MAX = 127 };
    int i;
    int idim = 0;
    int ellipsis = -1;

    /* Process all labels for this operand */
    for (i = 0; i < length; ++i) {
        int label = subscripts[i];

        /* A proper label for an axis. */
        if (label > 0 && isalpha(label)) {
            /* Check we don't exceed the operator dimensions. */
            if (idim >= ndim) {
                PyErr_Format(PyExc_ValueError,
                             "einstein sum subscripts string contains "
                             "too many subscripts for operand %d", iop);
                return -1;
            }

            op_labels[idim++] = label;
            if (label < *min_label) {
                *min_label = label;
            }
            if (label > *max_label) {
                *max_label = label;
            }
            /*
             * Saturate rather than wrap: the counts are stored in a
             * plain char, and a wrapped value of 0 or 1 would be
             * mistaken for "never used" or "used exactly once".
             */
            if (label_counts[label] < LABEL_COUNT_MAX) {
                label_counts[label]++;
            }
        }
        /* The beginning of the ellipsis. */
        else if (label == '.') {
            /* Check it's a proper ellipsis. */
            if (ellipsis != -1 || i + 2 >= length
                    || subscripts[++i] != '.' || subscripts[++i] != '.') {
                PyErr_Format(PyExc_ValueError,
                             "einstein sum subscripts string contains a "
                             "'.' that is not part of an ellipsis ('...') "
                             "in operand %d", iop);
                return -1;
            }

            /* Remember how many labels preceded the ellipsis. */
            ellipsis = idim;
        }
        /* Blanks are ignored, anything else is an error. */
        else if (label != ' ') {
            PyErr_Format(PyExc_ValueError,
                         "invalid subscript '%c' in einstein sum "
                         "subscripts string, subscripts must "
                         "be letters", (char)label);
            return -1;
        }
    }

    /* No ellipsis found, labels must match dimensions exactly. */
    if (ellipsis == -1) {
        if (idim != ndim) {
            PyErr_Format(PyExc_ValueError,
                         "operand has more dimensions than subscripts "
                         "given in einstein sum, but no '...' ellipsis "
                         "provided to broadcast the extra dimensions.");
            return -1;
        }
    }
    /* Ellipsis found, may have to add broadcast dimensions. */
    else if (idim < ndim) {
        /* Move labels after ellipsis to the end. */
        for (i = 0; i < idim - ellipsis; ++i) {
            op_labels[ndim - i - 1] = op_labels[idim - i - 1];
        }
        /* Set all broadcast dimensions to zero. */
        for (i = 0; i < ndim - idim; ++i) {
            op_labels[ellipsis + i] = 0;
        }
    }

    /*
     * Find any labels duplicated for this operand, and turn them
     * into negative offsets to the axis to merge with.
     *
     * In C, the char type may be signed or unsigned, but with
     * twos complement arithmetic the char is ok either way here, and
     * later where it matters the char is cast to a signed char.
     */
    for (idim = 0; idim < ndim - 1; ++idim) {
        int label = (signed char)op_labels[idim];
        /* If it is a proper label, find any duplicates of it. */
        if (label > 0) {
            /* Search for the next matching label. */
            char *next = (char*)memchr(op_labels + idim + 1, label, ndim - idim - 1);

            while (next != NULL) {
                /* The offset from next to op_labels[idim] (negative). */
                *next = (char)((op_labels + idim) - next);
                /* Search for the next matching label. */
                next = (char*)memchr(next + 1, label, op_labels + ndim - 1 - next);
            }
        }
    }

    return 0;
}


/*
 * Parses the output part of the subscripts string.
 *
 * Every letter becomes one entry of 'out_labels' (its ASCII code), and
 * an ellipsis expands into 'ndim_broadcast' zero entries, mirroring the
 * convention used by parse_operand_subscripts. Blanks are skipped.
 *
 * A letter is rejected when it occurs more than once in the output or
 * when its entry in 'label_counts' is zero, i.e. no input used it.
 *
 * Returns the total number of output dimensions written to
 * 'out_labels', or -1 with a ValueError set on failure.
 */
static int
parse_output_subscripts(char *subscripts, int length,
                        int ndim_broadcast,
                        const char *label_counts, char *out_labels)
{
    int pos;
    int nout = 0;
    int have_ellipsis = 0;

    for (pos = 0; pos < length; ++pos) {
        int ch = subscripts[pos];

        if (ch > 0 && isalpha(ch)) {
            /* The same letter must not show up later in the output. */
            if (memchr(subscripts + pos + 1, ch, length - pos - 1) != NULL) {
                PyErr_Format(PyExc_ValueError,
                             "einstein sum subscripts string includes "
                             "output subscript '%c' multiple times",
                             (char)ch);
                return -1;
            }
            /* The letter has to come from one of the inputs. */
            if (label_counts[ch] == 0) {
                PyErr_Format(PyExc_ValueError,
                             "einstein sum subscripts string included "
                             "output subscript '%c' which never appeared "
                             "in an input", (char)ch);
                return -1;
            }
            /* out_labels holds at most NPY_MAXDIMS entries. */
            if (nout >= NPY_MAXDIMS) {
                PyErr_Format(PyExc_ValueError,
                             "einstein sum subscripts string contains "
                             "too many subscripts in the output");
                return -1;
            }

            out_labels[nout] = ch;
            ++nout;
            continue;
        }

        if (ch == '.') {
            int fill;

            /* Only a single, complete "..." is acceptable. */
            if (have_ellipsis || pos + 2 >= length
                    || subscripts[pos + 1] != '.'
                    || subscripts[pos + 2] != '.') {
                PyErr_SetString(PyExc_ValueError,
                                "einstein sum subscripts string "
                                "contains a '.' that is not part of "
                                "an ellipsis ('...') in the output");
                return -1;
            }
            /* Step over the two remaining dots. */
            pos += 2;

            /* The broadcast dimensions must fit into out_labels too. */
            if (nout + ndim_broadcast > NPY_MAXDIMS) {
                PyErr_Format(PyExc_ValueError,
                             "einstein sum subscripts string contains "
                             "too many subscripts in the output");
                return -1;
            }

            have_ellipsis = 1;
            for (fill = 0; fill < ndim_broadcast; ++fill) {
                out_labels[nout] = 0;
                ++nout;
            }
            continue;
        }

        if (ch == ' ') {
            continue;
        }

        PyErr_Format(PyExc_ValueError,
                     "invalid subscript '%c' in einstein sum "
                     "subscripts string, subscripts must "
                     "be letters", (char)ch);
        return -1;
    }

    /* Broadcast dimensions need an ellipsis to land in. */
    if (ndim_broadcast > 0 && !have_ellipsis) {
        PyErr_SetString(PyExc_ValueError,
                        "output has more dimensions than subscripts "
                        "given in einstein sum, but no '...' ellipsis "
                        "provided to broadcast the extra dimensions.");
        return -1;
    }

    return nout;
}


/*
 * When there's just one operand and no reduction we can return a view
 * into 'op'.  This calculates the view and stores it in 'ret', if
 * possible.
 *
 * op:            The single input operand.
 * labels:        Its labels as produced by parse_operand_subscripts
 *                (label codes, negative offsets for repeats, zeros for
 *                broadcast dimensions).
 * ndim_output:   Number of entries in 'output_labels'.
 * output_labels: Labels of the output dimensions.
 * ret:           Receives the new view, or NULL when a label of the
 *                operand is missing from the output (a reduction), in
 *                which case no view can be made.
 *
 * Returns -1 on error (with an exception set), 0 otherwise.  Note that
 * a 0 return does not mean that a view was successfully created.
 */
static int
get_single_op_view(PyArrayObject *op, char *labels,
                   int ndim_output, char *output_labels,
                   PyArrayObject **ret)
{
    npy_intp new_strides[NPY_MAXDIMS];
    npy_intp new_dims[NPY_MAXDIMS];
    /*
     * Non-zero once the matching entry of new_dims has been assigned.
     * A separate flag is required because 0 is a valid dimension
     * length and so cannot double as the "not assigned yet" marker.
     */
    char dim_assigned[NPY_MAXDIMS];
    char *out_label;
    int label, i, idim, ndim, ibroadcast = 0;

    ndim = PyArray_NDIM(op);

    /* Initialize the dimensions and strides to zero */
    for (idim = 0; idim < ndim_output; ++idim) {
        new_dims[idim] = 0;
        new_strides[idim] = 0;
        dim_assigned[idim] = 0;
    }

    /* Match the labels in the operand with the output labels */
    for (idim = 0; idim < ndim; ++idim) {
        /*
         * The char type may be either signed or unsigned, we
         * need it to be signed here.
         */
        label = (signed char)labels[idim];
        /* If this label says to merge axes, get the actual label */
        if (label < 0) {
            label = labels[idim+label];
        }
        /* If the label is 0, it's an unlabeled broadcast dimension */
        if (label == 0) {
            /* The next output label that's a broadcast dimension */
            for (; ibroadcast < ndim_output; ++ibroadcast) {
                if (output_labels[ibroadcast] == 0) {
                    break;
                }
            }
            if (ibroadcast == ndim_output) {
                PyErr_SetString(PyExc_ValueError,
                        "output had too few broadcast dimensions");
                return -1;
            }
            new_dims[ibroadcast] = PyArray_DIM(op, idim);
            new_strides[ibroadcast] = PyArray_STRIDE(op, idim);
            dim_assigned[ibroadcast] = 1;
            ++ibroadcast;
        }
        else {
            /* Find the position for this dimension in the output */
            out_label = (char *)memchr(output_labels, label,
                                                    ndim_output);
            /* If it's not found, reduction -> can't return a view */
            if (out_label == NULL) {
                break;
            }
            /* Update the dimensions and strides of the output */
            i = out_label - output_labels;
            /*
             * Axes sharing a label must agree in length, including
             * the case where the first one seen has length zero.
             */
            if (dim_assigned[i] && new_dims[i] != PyArray_DIM(op, idim)) {
                PyErr_Format(PyExc_ValueError,
                        "dimensions in single operand for collapsing "
                        "index '%c' don't match (%d != %d)",
                        label, (int)new_dims[i], (int)PyArray_DIM(op, idim));
                return -1;
            }
            new_dims[i] = PyArray_DIM(op, idim);
            /* Summing the strides walks the diagonal of merged axes. */
            new_strides[i] += PyArray_STRIDE(op, idim);
            dim_assigned[i] = 1;
        }
    }
    /* If we processed all the input axes, return a view */
    if (idim == ndim) {
        Py_INCREF(PyArray_DESCR(op));
        *ret = (PyArrayObject *)PyArray_NewFromDescr_int(
                Py_TYPE(op), PyArray_DESCR(op),
                ndim_output, new_dims, new_strides, PyArray_DATA(op),
                PyArray_ISWRITEABLE(op) ? NPY_ARRAY_WRITEABLE : 0,
                (PyObject *)op, (PyObject *)op, (_NPY_CREATION_FLAGS)0);

        if (*ret == NULL) {
            return -1;
        }
        return 0;
    }

    /* Return success, but that we couldn't make a view */
    *ret = NULL;
    return 0;
}


/*
 * Returns 1 if any of the first 'ndim' entries of 'labels' is negative,
 * and 0 otherwise. The labels are taken as signed char because plain
 * char may be either signed or unsigned.
 */
static int
_any_labels_are_negative(signed char *labels, int ndim)
{
    const signed char *cursor = labels;
    int remaining = ndim;

    while (remaining > 0) {
        if (*cursor < 0) {
            return 1;
        }
        ++cursor;
        --remaining;
    }

    return 0;
}

/*
 * Given the labels for an operand array, returns a view of the array
 * in which every group of repeated labels is collapsed into a single
 * dimension running along the corresponding diagonal. 'labels' is
 * rewritten in place so that it describes the dimensions of the view.
 *
 * When no label is repeated, a new reference to 'op' itself is
 * returned and 'labels' is left untouched.
 *
 * Returns NULL with a ValueError set if two axes sharing a label have
 * different lengths; 'iop' is only used in that error message.
 */
static PyArrayObject *
get_combined_dims_view(PyArrayObject *op, int iop, char *labels)
{
    npy_intp view_strides[NPY_MAXDIMS];
    npy_intp view_dims[NPY_MAXDIMS];
    /* Position of each original axis in the view, -1 if merged away. */
    int axis_to_view[NPY_MAXDIMS];
    int axis;
    int nview = 0;
    int op_ndim = PyArray_NDIM(op);

    /* Nothing to merge: hand back the operand itself. */
    if (!_any_labels_are_negative((signed char *)labels, op_ndim)) {
        Py_INCREF(op);
        return op;
    }

    for (axis = 0; axis < op_ndim; ++axis) {
        /* Plain char may be unsigned, the offsets need a signed read. */
        int label = (signed char)labels[axis];
        npy_intp dim = PyArray_DIM(op, axis);
        npy_intp stride = PyArray_STRIDE(op, axis);
        int target;

        /* First occurrence (or broadcast axis): gets its own slot. */
        if (label >= 0) {
            axis_to_view[axis] = nview;
            view_dims[nview] = dim;
            view_strides[nview] = stride;
            ++nview;
            continue;
        }

        /* Repeat: fold it into the slot of the first occurrence. */
        target = axis_to_view[axis + label];
        axis_to_view[axis] = -1;
        if (view_dims[target] != dim) {
            char first_label = labels[axis + label];

            PyErr_Format(PyExc_ValueError,
                         "dimensions in operand %d for collapsing "
                         "index '%c' don't match (%d != %d)",
                         iop, first_label, (int)view_dims[target], (int)dim);
            return NULL;
        }
        view_strides[target] += stride;
    }

    /* Compact the labels so they line up with the view's dimensions. */
    for (axis = 0; axis < op_ndim; ++axis) {
        int slot = axis_to_view[axis];

        if (slot < 0) {
            continue;
        }
        labels[slot] = labels[axis];
    }

    /* Build the view; it keeps 'op' alive as its base. */
    Py_INCREF(PyArray_DESCR(op));
    return (PyArrayObject *)PyArray_NewFromDescrAndBase(
            Py_TYPE(op), PyArray_DESCR(op),
            nview, view_dims, view_strides, PyArray_DATA(op),
            PyArray_ISWRITEABLE(op) ? NPY_ARRAY_WRITEABLE : 0,
            (PyObject *)op, (PyObject *)op);
}

/*
 * Fills 'axes' with, for each of the 'ndim_iter' iterator dimensions,
 * the index of the matching dimension of an operand, or -1 when the
 * operand has no such dimension.
 *
 * 'labels' holds the operand's 'ndim' labels and 'iter_labels' the
 * labels of the iterator dimensions. Unlabeled (zero) iterator
 * dimensions are paired with the operand's unlabeled dimensions from
 * the right; labeled ones are looked up by value. 'iop' is not used.
 *
 * Always returns 0.
 */
static int
prepare_op_axes(int ndim, int iop, char *labels, int *axes,
            int ndim_iter, char *iter_labels)
{
    /* Rightmost operand dimension not yet examined for broadcasting. */
    int next_bcast = ndim - 1;
    int it = ndim_iter;

    while (it > 0) {
        int wanted;

        --it;
        wanted = iter_labels[it];

        /* Labeled dimension: use the operand axis carrying the label. */
        if (wanted != 0) {
            char *hit = (char *)memchr(labels, wanted, ndim);

            if (hit != NULL) {
                axes[it] = hit - labels;
            }
            else {
                /* The operand lacks this label, so it is broadcast. */
                axes[it] = -1;
            }
            continue;
        }

        /* Unlabeled dimension: skip left over labeled operand axes. */
        while (next_bcast >= 0 && labels[next_bcast] != 0) {
            --next_bcast;
        }
        if (next_bcast >= 0) {
            axes[it] = next_bcast;
            --next_bcast;
        }
        else {
            /* No broadcast axes left in the operand: act as newaxis. */
            axes[it] = -1;
        }
    }

    return 0;
}

/*
 * Hand-coded unbuffered loop over a two-dimensional iterator with one
 * input operand plus the output (two data pointers per axis).
 *
 * The sum-of-products kernel is run along axis 0 once for every
 * position along axis 1. Returns 0 on success, or -1 with an exception
 * set if no kernel exists for the dtype or an error is pending after a
 * kernel call that needs the Python API.
 */
static int
unbuffered_loop_nop1_ndim2(NpyIter *iter)
{
    npy_intp outer, shape[2], strides[2][2];
    /* ptrs[1] remembers the row start, ptrs[0] is given to the kernel. */
    char *ptrs[2][2];
    sum_of_products_fn sop;
    int iarg;
    NPY_BEGIN_THREADS_DEF;

#if NPY_EINSUM_DBG_TRACING
    NpyIter_DebugPrint(iter);
#endif
    NPY_EINSUM_DBG_PRINT("running hand-coded 1-op 2-dim loop\n");

    NpyIter_GetShape(iter, shape);
    memcpy(strides[0], NpyIter_GetAxisStrideArray(iter, 0),
                                            2*sizeof(npy_intp));
    memcpy(strides[1], NpyIter_GetAxisStrideArray(iter, 1),
                                            2*sizeof(npy_intp));
    memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter),
                                            2*sizeof(char *));
    for (iarg = 0; iarg < 2; ++iarg) {
        ptrs[1][iarg] = ptrs[0][iarg];
    }

    sop = get_sum_of_products_function(1,
                    NpyIter_GetDescrArray(iter)[0]->type_num,
                    NpyIter_GetDescrArray(iter)[0]->elsize,
                    strides[0]);

    if (sop == NULL) {
        PyErr_SetString(PyExc_TypeError,
                    "invalid data type for einsum");
        return -1;
    }

    /* IterationNeedsAPI effectively only checks for object dtype here. */
    int needs_api = NpyIter_IterationNeedsAPI(iter);
    if (!needs_api) {
        NPY_BEGIN_THREADS_THRESHOLDED(shape[1] * shape[0]);
    }

    /*
     * Since the iterator wasn't tracking coordinates, the
     * loop provided by the iterator is in Fortran-order.
     */
    for (outer = 0; outer < shape[1]; ++outer) {
        sop(1, ptrs[0], strides[0], shape[0]);

        if (needs_api && PyErr_Occurred()){
            return -1;
        }

        /* Step every operand along axis 1 and reset the kernel ptrs. */
        for (iarg = 0; iarg < 2; ++iarg) {
            ptrs[1][iarg] += strides[1][iarg];
            ptrs[0][iarg] = ptrs[1][iarg];
        }
    }
    NPY_END_THREADS;

    return 0;
}

/*
 * Hand-coded unbuffered loop over a three-dimensional iterator with one
 * input operand plus the output (two data pointers per axis).
 *
 * The sum-of-products kernel is run along axis 0 for every combination
 * of positions along axes 1 and 2. Returns 0 on success, or -1 with an
 * exception set if no kernel exists for the dtype or an error is
 * pending after a kernel call that needs the Python API.
 */
static int
unbuffered_loop_nop1_ndim3(NpyIter *iter)
{
    npy_intp mid, outer, shape[3], strides[3][2];
    /*
     * ptrs[2] and ptrs[1] remember the start of the current axis-2 and
     * axis-1 positions, ptrs[0] is what gets handed to the kernel.
     */
    char *ptrs[3][2];
    sum_of_products_fn sop;
    int level, iarg;
    NPY_BEGIN_THREADS_DEF;

#if NPY_EINSUM_DBG_TRACING
    NpyIter_DebugPrint(iter);
#endif
    NPY_EINSUM_DBG_PRINT("running hand-coded 1-op 3-dim loop\n");

    NpyIter_GetShape(iter, shape);
    for (level = 0; level < 3; ++level) {
        memcpy(strides[level], NpyIter_GetAxisStrideArray(iter, level),
                                            2*sizeof(npy_intp));
    }
    memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter),
                                            2*sizeof(char *));
    for (level = 1; level < 3; ++level) {
        memcpy(ptrs[level], ptrs[0], 2*sizeof(char*));
    }

    sop = get_sum_of_products_function(1,
                    NpyIter_GetDescrArray(iter)[0]->type_num,
                    NpyIter_GetDescrArray(iter)[0]->elsize,
                    strides[0]);

    if (sop == NULL) {
        PyErr_SetString(PyExc_TypeError,
                    "invalid data type for einsum");
        return -1;
    }

    /* IterationNeedsAPI effectively only checks for object dtype here. */
    int needs_api = NpyIter_IterationNeedsAPI(iter);
    if (!needs_api) {
        NPY_BEGIN_THREADS_THRESHOLDED(shape[2] * shape[1] * shape[0]);
    }

    /*
     * Since the iterator wasn't tracking coordinates, the
     * loop provided by the iterator is in Fortran-order.
     */
    for (outer = 0; outer < shape[2]; ++outer) {
        for (mid = 0; mid < shape[1]; ++mid) {
            sop(1, ptrs[0], strides[0], shape[0]);

            if (needs_api && PyErr_Occurred()){
                return -1;
            }

            /* Step along axis 1 and reset the kernel pointers. */
            for (iarg = 0; iarg < 2; ++iarg) {
                ptrs[1][iarg] += strides[1][iarg];
                ptrs[0][iarg] = ptrs[1][iarg];
            }
        }
        /* Step along axis 2 and restart both inner levels from there. */
        for (iarg = 0; iarg < 2; ++iarg) {
            ptrs[2][iarg] += strides[2][iarg];
            ptrs[1][iarg] = ptrs[2][iarg];
            ptrs[0][iarg] = ptrs[2][iarg];
        }
    }
    NPY_END_THREADS;

    return 0;
}

/*
 * Hand-coded unbuffered loop over a two-dimensional iterator with two
 * input operands plus the output (three data pointers per axis).
 *
 * The sum-of-products kernel is run along axis 0 once for every
 * position along axis 1. Returns 0 on success, or -1 with an exception
 * set if no kernel exists for the dtype or an error is pending after a
 * kernel call that needs the Python API.
 */
static int
unbuffered_loop_nop2_ndim2(NpyIter *iter)
{
    npy_intp outer, shape[2], strides[2][3];
    /* ptrs[1] remembers the row start, ptrs[0] is given to the kernel. */
    char *ptrs[2][3];
    sum_of_products_fn sop;
    int iarg;
    NPY_BEGIN_THREADS_DEF;

#if NPY_EINSUM_DBG_TRACING
    NpyIter_DebugPrint(iter);
#endif
    NPY_EINSUM_DBG_PRINT("running hand-coded 2-op 2-dim loop\n");

    NpyIter_GetShape(iter, shape);
    memcpy(strides[0], NpyIter_GetAxisStrideArray(iter, 0),
                                            3*sizeof(npy_intp));
    memcpy(strides[1], NpyIter_GetAxisStrideArray(iter, 1),
                                            3*sizeof(npy_intp));
    memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter),
                                            3*sizeof(char *));
    for (iarg = 0; iarg < 3; ++iarg) {
        ptrs[1][iarg] = ptrs[0][iarg];
    }

    sop = get_sum_of_products_function(2,
                    NpyIter_GetDescrArray(iter)[0]->type_num,
                    NpyIter_GetDescrArray(iter)[0]->elsize,
                    strides[0]);

    if (sop == NULL) {
        PyErr_SetString(PyExc_TypeError,
                    "invalid data type for einsum");
        return -1;
    }

    /* IterationNeedsAPI effectively only checks for object dtype here. */
    int needs_api = NpyIter_IterationNeedsAPI(iter);
    if (!needs_api) {
        NPY_BEGIN_THREADS_THRESHOLDED(shape[1] * shape[0]);
    }

    /*
     * Since the iterator wasn't tracking coordinates, the
     * loop provided by the iterator is in Fortran-order.
     */
    for (outer = 0; outer < shape[1]; ++outer) {
        sop(2, ptrs[0], strides[0], shape[0]);

        if(needs_api && PyErr_Occurred()){
            return -1;
        }

        /* Step every operand along axis 1 and reset the kernel ptrs. */
        for (iarg = 0; iarg < 3; ++iarg) {
            ptrs[1][iarg] += strides[1][iarg];
            ptrs[0][iarg] = ptrs[1][iarg];
        }
    }
    NPY_END_THREADS;

    return 0;
}

/*
 * Hand-coded unbuffered loop over a three-dimensional iterator with two
 * input operands plus the output (three data pointers per axis).
 *
 * The sum-of-products kernel is run along axis 0 for every combination
 * of positions along axes 1 and 2. Returns 0 on success, or -1 with an
 * exception set if no kernel exists for the dtype or an error is
 * pending after a kernel call that needs the Python API.
 */
static int
unbuffered_loop_nop2_ndim3(NpyIter *iter)
{
    npy_intp mid, outer, shape[3], strides[3][3];
    /*
     * ptrs[2] and ptrs[1] remember the start of the current axis-2 and
     * axis-1 positions, ptrs[0] is what gets handed to the kernel.
     */
    char *ptrs[3][3];
    sum_of_products_fn sop;
    int level, iarg;
    NPY_BEGIN_THREADS_DEF;

#if NPY_EINSUM_DBG_TRACING
    NpyIter_DebugPrint(iter);
#endif
    NPY_EINSUM_DBG_PRINT("running hand-coded 2-op 3-dim loop\n");

    NpyIter_GetShape(iter, shape);
    for (level = 0; level < 3; ++level) {
        memcpy(strides[level], NpyIter_GetAxisStrideArray(iter, level),
                                            3*sizeof(npy_intp));
    }
    memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter),
                                            3*sizeof(char *));
    for (level = 1; level < 3; ++level) {
        memcpy(ptrs[level], ptrs[0], 3*sizeof(char*));
    }

    sop = get_sum_of_products_function(2,
                    NpyIter_GetDescrArray(iter)[0]->type_num,
                    NpyIter_GetDescrArray(iter)[0]->elsize,
                    strides[0]);

    if (sop == NULL) {
        PyErr_SetString(PyExc_TypeError,
                    "invalid data type for einsum");
        return -1;
    }

    /* IterationNeedsAPI effectively only checks for object dtype here. */
    int needs_api = NpyIter_IterationNeedsAPI(iter);
    if (!needs_api) {
        NPY_BEGIN_THREADS_THRESHOLDED(shape[2] * shape[1] * shape[0]);
    }

    /*
     * Since the iterator wasn't tracking coordinates, the
     * loop provided by the iterator is in Fortran-order.
     */
    for (outer = 0; outer < shape[2]; ++outer) {
        for (mid = 0; mid < shape[1]; ++mid) {
            sop(2, ptrs[0], strides[0], shape[0]);

            if(needs_api && PyErr_Occurred()){
                return -1;
            }

            /* Step along axis 1 and reset the kernel pointers. */
            for (iarg = 0; iarg < 3; ++iarg) {
                ptrs[1][iarg] += strides[1][iarg];
                ptrs[0][iarg] = ptrs[1][iarg];
            }
        }
        /* Step along axis 2 and restart both inner levels from there. */
        for (iarg = 0; iarg < 3; ++iarg) {
            ptrs[2][iarg] += strides[2][iarg];
            ptrs[1][iarg] = ptrs[2][iarg];
            ptrs[0][iarg] = ptrs[2][iarg];
        }
    }
    NPY_END_THREADS;

    return 0;
}


/*NUMPY_API
 * This function provides summation of array elements according to
 * the Einstein summation convention.  For example:
 *  - trace(a)        -> einsum("ii", a)
 *  - transpose(a)    -> einsum("ji", a)
 *  - multiply(a,b)   -> einsum(",", a, b)
 *  - inner(a,b)      -> einsum("i,i", a, b)
 *  - outer(a,b)      -> einsum("i,j", a, b)
 *  - matvec(a,b)     -> einsum("ij,j", a, b)
 *  - matmat(a,b)     -> einsum("ij,jk", a, b)
 *
 * subscripts: The string of subscripts for einstein summation.
 * nop:        The number of operands
 * op_in:      The array of operands
 * dtype:      Either NULL, or the data type to force the calculation as.
 * order:      The order for the calculation/the output axes.
 * casting:    What kind of casts should be permitted.
 * out:        Either NULL, or an array into which the output should be placed.
 *
 * By default, the labels get placed in alphabetical order
 * at the end of the output. So, if c = einsum("i,j", a, b)
 * then c[i,j] == a[i]*b[j], but if c = einsum("j,i", a, b)
 * then c[i,j] = a[j]*b[i].
 *
 * Alternatively, you can control the output order or prevent
 * an axis from being summed/force an axis to be summed by providing
 * indices for the output. This allows us to turn 'trace' into
 * 'diag', for example.
 *  - diag(a)         -> einsum("ii->i", a)
 *  - sum(a, axis=0)  -> einsum("i...->...", a)
 *
 * Subscripts at the beginning and end may be specified by
 * putting an ellipsis "..." in the middle.  For example,
 * the function einsum("i...i", a) takes the diagonal of
 * the first and last dimensions of the operand, and
 * einsum("ij...,jk...->ik...") takes the matrix product using
 * the first two indices of each operand instead of the last two.
 *
 * When there is only one operand, no axes being summed, and
 * no output parameter, this function returns a view
 * into the operand instead of making a copy.
 */
NPY_NO_EXPORT PyArrayObject *
PyArray_EinsteinSum(char *subscripts, npy_intp nop,
                    PyArrayObject **op_in,
                    PyArray_Descr *dtype,
                    NPY_ORDER order, NPY_CASTING casting,
                    PyArrayObject *out)
{
    int iop, label, min_label = 127, max_label = 0;
    char label_counts[128];
    char op_labels[NPY_MAXARGS][NPY_MAXDIMS];
    char output_labels[NPY_MAXDIMS], *iter_labels;
    int idim, ndim_output, ndim_broadcast, ndim_iter;

    PyArrayObject *op[NPY_MAXARGS], *ret = NULL;
    PyArray_Descr *op_dtypes_array[NPY_MAXARGS], **op_dtypes;

    int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
    int *op_axes[NPY_MAXARGS];
    npy_uint32 iter_flags, op_flags[NPY_MAXARGS];

    NpyIter *iter = NULL;
    sum_of_products_fn sop;
    npy_intp *stride;

    /* nop+1 (+1 is for the output) must fit in NPY_MAXARGS */
    if (nop >= NPY_MAXARGS) {
        PyErr_SetString(PyExc_ValueError,
                    "too many operands provided to einstein sum function");
        return NULL;
    }
    else if (nop < 1) {
        PyErr_SetString(PyExc_ValueError,
                    "not enough operands provided to einstein sum function");
        return NULL;
    }

    /* Parse the subscripts string into label_counts and op_labels */
    memset(label_counts, 0, sizeof(label_counts));
    for (iop = 0; iop < nop; ++iop) {
        int length = (int)strcspn(subscripts, ",-");

        if (iop == nop-1 && subscripts[length] == ',') {
            PyErr_SetString(PyExc_ValueError,
                        "more operands provided to einstein sum function "
                        "than specified in the subscripts string");
            return NULL;
        }
        else if(iop < nop-1 && subscripts[length] != ',') {
            PyErr_SetString(PyExc_ValueError,
                        "fewer operands provided to einstein sum function "
                        "than specified in the subscripts string");
            return NULL;
        }

        if (parse_operand_subscripts(subscripts, length,
                        PyArray_NDIM(op_in[iop]),
                        iop, op_labels[iop], label_counts,
                        &min_label, &max_label) < 0) {
            return NULL;
        }

        /* Move subscripts to the start of the labels for the next op */
        subscripts += length;
        if (iop < nop-1) {
            subscripts++;
        }
    }

    /*
     * Find the number of broadcast dimensions, which is the maximum
     * number of labels == 0 in an op_labels array.
     */
    ndim_broadcast = 0;
    for (iop = 0; iop < nop; ++iop) {
        npy_intp count_zeros = 0;
        int ndim;
        char *labels = op_labels[iop];

        ndim = PyArray_NDIM(op_in[iop]);
        for (idim = 0; idim < ndim; ++idim) {
            if (labels[idim] == 0) {
                ++count_zeros;
            }
        }

        if (count_zeros > ndim_broadcast) {
            ndim_broadcast = count_zeros;
        }
    }

    /*
     * If there is no output signature, fill output_labels and ndim_output
     * using each label that appeared once, in alphabetical order.
     */
    if (subscripts[0] == '\0') {
        /* If no output was specified, always broadcast left, as usual. */
        for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
            output_labels[ndim_output] = 0;
        }
        for (label = min_label; label <= max_label; ++label) {
            if (label_counts[label] == 1) {
                if (ndim_output < NPY_MAXDIMS) {
                    output_labels[ndim_output++] = label;
                }
                else {
                    PyErr_SetString(PyExc_ValueError,
                                "einstein sum subscript string has too many "
                                "distinct labels");
                    return NULL;
                }
            }
        }
    }
    else {
        if (subscripts[0] != '-' || subscripts[1] != '>') {
            PyErr_SetString(PyExc_ValueError,
                        "einstein sum subscript string does not "
                        "contain proper '->' output specified");
            return NULL;
        }
        subscripts += 2;

        /* Parse the output subscript string. */
        ndim_output = parse_output_subscripts(subscripts, strlen(subscripts),
                                        ndim_broadcast, label_counts,
                                        output_labels);
        if (ndim_output < 0) {
            return NULL;
        }
    }

    if (out != NULL && PyArray_NDIM(out) != ndim_output) {
        PyErr_Format(PyExc_ValueError,
                "out parameter does not have the correct number of "
                "dimensions, has %d but should have %d",
                (int)PyArray_NDIM(out), (int)ndim_output);
        return NULL;
    }

    /*
     * If there's just one operand and no output parameter,
     * first try remapping the axes to the output to return
     * a view instead of a copy.
     */
    if (nop == 1 && out == NULL) {
        ret = NULL;

        if (get_single_op_view(op_in[0], op_labels[0], ndim_output,
                               output_labels, &ret) < 0) {
            return NULL;
        }

        if (ret != NULL) {
            return ret;
        }
    }

    /* Set all the op references to NULL */
    for (iop = 0; iop < nop; ++iop) {
        op[iop] = NULL;
    }

    /*
     * Process all the input ops, combining dimensions into their
     * diagonal where specified.
     */
    for (iop = 0; iop < nop; ++iop) {
        char *labels = op_labels[iop];

        op[iop] = get_combined_dims_view(op_in[iop], iop, labels);
        if (op[iop] == NULL) {
            goto fail;
        }
    }

    /* Set the output op */
    op[nop] = out;

    /*
     * Set up the labels for the iterator (output + combined labels).
     * Can just share the output_labels memory, because iter_labels
     * is output_labels with some more labels appended.
     */
    iter_labels = output_labels;
    ndim_iter = ndim_output;
    for (label = min_label; label <= max_label; ++label) {
        if (label_counts[label] > 0 &&
                memchr(output_labels, label, ndim_output) == NULL) {
            if (ndim_iter >= NPY_MAXDIMS) {
                PyErr_SetString(PyExc_ValueError,
                            "too many subscripts in einsum");
                goto fail;
            }
            iter_labels[ndim_iter++] = label;
        }
    }

    /* Set up the op_axes for the iterator */
    for (iop = 0; iop < nop; ++iop) {
        op_axes[iop] = op_axes_arrays[iop];

        if (prepare_op_axes(PyArray_NDIM(op[iop]), iop, op_labels[iop],
                    op_axes[iop], ndim_iter, iter_labels) < 0) {
            goto fail;
        }
    }

    /* Set up the op_dtypes if dtype was provided */
    if (dtype == NULL) {
        op_dtypes = NULL;
    }
    else {
        op_dtypes = op_dtypes_array;
        for (iop = 0; iop <= nop; ++iop) {
            op_dtypes[iop] = dtype;
        }
    }

    /* Set the op_axes for the output */
    op_axes[nop] = op_axes_arrays[nop];
    for (idim = 0; idim < ndim_output; ++idim) {
        op_axes[nop][idim] = idim;
    }
    for (idim = ndim_output; idim < ndim_iter; ++idim) {
        op_axes[nop][idim] = NPY_ITER_REDUCTION_AXIS(-1);
    }

    /* Set the iterator per-op flags */

    for (iop = 0; iop < nop; ++iop) {
        op_flags[iop] = NPY_ITER_READONLY|
                        NPY_ITER_NBO|
                        NPY_ITER_ALIGNED;
    }
    op_flags[nop] = NPY_ITER_READWRITE|
                    NPY_ITER_NBO|
                    NPY_ITER_ALIGNED|
                    NPY_ITER_ALLOCATE;
    /*
     * Note: We skip GROWINNER here because this gives a partially stable
     * summation for float64.  Pairwise summation would be better.
     */
    iter_flags = NPY_ITER_EXTERNAL_LOOP|
            NPY_ITER_BUFFERED|
            NPY_ITER_DELAY_BUFALLOC|
            NPY_ITER_REFS_OK|
            NPY_ITER_ZEROSIZE_OK;
    if (out != NULL) {
        iter_flags |= NPY_ITER_COPY_IF_OVERLAP;
    }
    if (dtype == NULL) {
        iter_flags |= NPY_ITER_COMMON_DTYPE;
    }

    /* Allocate the iterator */
    iter = NpyIter_AdvancedNew(nop+1, op, iter_flags, order, casting, op_flags,
                               op_dtypes, ndim_iter, op_axes, NULL, 0);

    if (iter == NULL) {
        goto fail;
    }

    /* Initialize the output to all zeros or None */
    ret = NpyIter_GetOperandArray(iter)[nop];
    if (PyArray_AssignZero(ret, NULL) < 0) {
        goto fail;
    }

    /***************************/
    /*
     * Acceleration for some specific loop structures. Note
     * that with axis coalescing, inputs with more dimensions can
     * be reduced to fit into these patterns.
     */
    if (!NpyIter_RequiresBuffering(iter)) {
        int ndim = NpyIter_GetNDim(iter);
        switch (nop) {
            case 1:
                if (ndim == 2) {
                    if (unbuffered_loop_nop1_ndim2(iter) < 0) {
                        goto fail;
                    }
                    goto finish;
                }
                else if (ndim == 3) {
                    if (unbuffered_loop_nop1_ndim3(iter) < 0) {
                        goto fail;
                    }
                    goto finish;
                }
                break;
            case 2:
                if (ndim == 2) {
                    if (unbuffered_loop_nop2_ndim2(iter) < 0) {
                        goto fail;
                    }
                    goto finish;
                }
                else if (ndim == 3) {
                    if (unbuffered_loop_nop2_ndim3(iter) < 0) {
                        goto fail;
                    }
                    goto finish;
                }
                break;
        }
    }
    /***************************/

    if (NpyIter_Reset(iter, NULL) != NPY_SUCCEED) {
        goto fail;
    }

    /*
     * Get an inner loop function, specializing it based on
     * the strides that are fixed for the whole loop.
     */
    stride = NpyIter_GetInnerStrideArray(iter);

    sop = get_sum_of_products_function(nop,
                        NpyIter_GetDescrArray(iter)[0]->type_num,
                        NpyIter_GetDescrArray(iter)[0]->elsize,
                        stride);

#if NPY_EINSUM_DBG_TRACING
    NpyIter_DebugPrint(iter);
#endif

    /* Finally, the main loop */
    if (sop == NULL) {
        PyErr_SetString(PyExc_TypeError,
                    "invalid data type for einsum");
    }
    else if (NpyIter_GetIterSize(iter) != 0) {
        NpyIter_IterNextFunc *iternext;
        char **dataptr;
        npy_intp *countptr;
        NPY_BEGIN_THREADS_DEF;

        iternext = NpyIter_GetIterNext(iter, NULL);
        if (iternext == NULL) {
            goto fail;
        }
        dataptr = NpyIter_GetDataPtrArray(iter);
        countptr = NpyIter_GetInnerLoopSizePtr(iter);
        /* IterationNeedsAPI additionally checks for object dtype here. */
        int needs_api = NpyIter_IterationNeedsAPI(iter);
        if (!needs_api) {
            NPY_BEGIN_THREADS_THRESHOLDED(NpyIter_GetIterSize(iter));
        }

        NPY_EINSUM_DBG_PRINT("Einsum loop\n");
        do {
            sop(nop, dataptr, stride, *countptr);
        } while (!(needs_api && PyErr_Occurred()) && iternext(iter));
        NPY_END_THREADS;

        /* If the API was needed, it may have thrown an error */
        if (needs_api && PyErr_Occurred()) {
            goto fail;
        }
    }

finish:
    if (out != NULL) {
        ret = out;
    }
    Py_INCREF(ret);

    NpyIter_Deallocate(iter);
    for (iop = 0; iop < nop; ++iop) {
        Py_DECREF(op[iop]);
    }

    return ret;

fail:
    NpyIter_Deallocate(iter);
    for (iop = 0; iop < nop; ++iop) {
        Py_XDECREF(op[iop]);
    }

    return NULL;
}
